{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip\n",
    "# !unzip uncased_L-12_H-768_A-12.zip\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/NLP-Models-Tensorflow/master/abstractive-summarization/dataset/ctexts.json\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/NLP-Models-Tensorflow/master/abstractive-summarization/dataset/headlines.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_inspect.py:75: DeprecationWarning: inspect.getargspec() is deprecated since Python 3.0, use inspect.signature() or inspect.getfullargspec()\n",
      "  return _inspect.getargspec(target)\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_inspect.py:75: DeprecationWarning: inspect.getargspec() is deprecated since Python 3.0, use inspect.signature() or inspect.getfullargspec()\n",
      "  return _inspect.getargspec(target)\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_inspect.py:75: DeprecationWarning: inspect.getargspec() is deprecated since Python 3.0, use inspect.signature() or inspect.getfullargspec()\n",
      "  return _inspect.getargspec(target)\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_inspect.py:75: DeprecationWarning: inspect.getargspec() is deprecated since Python 3.0, use inspect.signature() or inspect.getfullargspec()\n",
      "  return _inspect.getargspec(target)\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_inspect.py:75: DeprecationWarning: inspect.getargspec() is deprecated since Python 3.0, use inspect.signature() or inspect.getfullargspec()\n",
      "  return _inspect.getargspec(target)\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_inspect.py:75: DeprecationWarning: inspect.getargspec() is deprecated since Python 3.0, use inspect.signature() or inspect.getfullargspec()\n",
      "  return _inspect.getargspec(target)\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_inspect.py:75: DeprecationWarning: inspect.getargspec() is deprecated since Python 3.0, use inspect.signature() or inspect.getfullargspec()\n",
      "  return _inspect.getargspec(target)\n",
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_inspect.py:75: DeprecationWarning: inspect.getargspec() is deprecated since Python 3.0, use inspect.signature() or inspect.getfullargspec()\n",
      "  return _inspect.getargspec(target)\n"
     ]
    }
   ],
   "source": [
    "import bert\n",
    "from bert import run_classifier\n",
    "from bert import optimization\n",
    "from bert import tokenization\n",
    "from bert import modeling\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensor2tensor.utils import beam_search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from unidecode import unidecode\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Article texts and their headlines, loaded from the two dataset files.\n",
    "# Presumably index-aligned lists (they are indexed in parallel further down) - verify against the dataset.\n",
    "with open('ctexts.json') as ctexts_file:\n",
    "    ctexts = json.load(ctexts_file)\n",
    "\n",
    "with open('headlines.json') as headlines_file:\n",
    "    headlines = json.load(headlines_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Files of the pre-trained uncased_L-12_H-768_A-12 BERT release, relative to the working directory:\n",
    "# WordPiece vocabulary, checkpoint prefix and model configuration.\n",
    "BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'\n",
    "BERT_INIT_CHKPNT = 'uncased_L-12_H-768_A-12/bert_model.ckpt'\n",
    "BERT_CONFIG = 'uncased_L-12_H-768_A-12/bert_config.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192\n",
      "  return f(*args, **kwds)\n",
      "/usr/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "\n",
    "def topic_modelling(string, n = 200):\n",
    "    \"\"\"Return up to `n` vocabulary terms of `string`, highest-weighted first, joined by spaces.\n",
    "\n",
    "    The text is vectorised with TF-IDF as a one-document corpus and reduced to a\n",
    "    single TruncatedSVD component; terms are ordered by that component's weights,\n",
    "    largest first.\n",
    "    \"\"\"\n",
    "    tfidf = TfidfVectorizer()\n",
    "    doc_term = tfidf.fit_transform([string])\n",
    "    vocab = tfidf.get_feature_names()\n",
    "    svd = TruncatedSVD(1)\n",
    "    svd.fit(doc_term)\n",
    "    # argsort is ascending, so walk it backwards and stop after n entries.\n",
    "    ranked = svd.components_[0].argsort()[: -n - 1 : -1]\n",
    "    return ' '.join([vocab[idx] for idx in ranked])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "# Silence these two warning categories for the rest of the session.\n",
    "for ignored_category in (PendingDeprecationWarning, RuntimeWarning):\n",
    "    warnings.simplefilter(\"ignore\", category=ignored_category)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 37.5 s, sys: 51.9 s, total: 1min 29s\n",
      "Wall time: 22.5 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Build aligned (topic words, headline) pairs; articles that cannot be processed are skipped.\n",
    "h, c = [], []\n",
    "n_skipped = 0\n",
    "for i in range(len(ctexts)):\n",
    "    try:\n",
    "        # Resolve both sides before appending so `c` and `h` can never get out of step.\n",
    "        topics = topic_modelling(ctexts[i])\n",
    "        headline = headlines[i]\n",
    "    except Exception:\n",
    "        # Best-effort skip (e.g. text without any usable token). Catching `Exception`\n",
    "        # instead of using a bare except lets KeyboardInterrupt stop the loop.\n",
    "        n_skipped += 1\n",
    "        continue\n",
    "    c.append(topics)\n",
    "    h.append(headline)\n",
    "print('skipped %d of %d articles' % (n_skipped, len(ctexts)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import itertools\n",
    "\n",
    "def build_dataset(words, n_words, atleast=1):\n",
    "    \"\"\"Index a token stream against a frequency-ranked vocabulary.\n",
    "\n",
    "    Ids 0-3 are reserved for the special tokens PAD, GO, EOS and UNK. The\n",
    "    `n_words` most common tokens that occur at least `atleast` times follow,\n",
    "    most frequent first.\n",
    "\n",
    "    Args:\n",
    "        words: list of tokens (it is traversed more than once).\n",
    "        n_words: maximum number of distinct corpus tokens to keep.\n",
    "        atleast: minimum frequency for a token to enter the vocabulary.\n",
    "\n",
    "    Returns:\n",
    "        (data, count, dictionary, reversed_dictionary):\n",
    "        data - `words` mapped to ids; out-of-vocabulary tokens map to the UNK id.\n",
    "        count - the four special entries followed by (token, frequency) pairs;\n",
    "                the UNK entry holds the number of out-of-vocabulary tokens.\n",
    "        dictionary - token -> id.\n",
    "        reversed_dictionary - id -> token.\n",
    "    \"\"\"\n",
    "    count = [['PAD', 0], ['GO', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    frequent = collections.Counter(words).most_common(n_words)\n",
    "    count.extend(pair for pair in frequent if pair[1] >= atleast)\n",
    "\n",
    "    dictionary = {}\n",
    "    for word, _ in count:\n",
    "        # setdefault keeps the reserved ids stable (and all ids unique) even if the\n",
    "        # corpus itself contains a token spelled like a special token, e.g. 'EOS'.\n",
    "        dictionary.setdefault(word, len(dictionary))\n",
    "\n",
    "    unk_id = dictionary['UNK']\n",
    "    # Unknown tokens must become UNK; id 0 is PAD and would be masked out downstream.\n",
    "    data = [dictionary.get(word, unk_id) for word in words]\n",
    "    count[unk_id][1] = sum(1 for word in words if word not in dictionary)\n",
    "\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 8534\n",
      "Most common words [('to', 1388), ('in', 1196), ('comma', 876), ('s', 785), ('for', 733), ('of', 596)]\n",
      "Sample data [2797, 14, 2798, 2799, 656, 2800, 5, 1642, 657, 2086] ['daman', 'and', 'diu', 'revokes', 'mandatory', 'rakshabandhan', 'in', 'offices', 'order', 'malaika']\n",
      "filtered vocab size: 8538\n",
      "% of vocab used: 100.05%\n"
     ]
    }
   ],
   "source": [
    "# Headline vocabulary: every distinct whitespace-separated token of every headline is kept.\n",
    "concat = ' '.join(h).split()\n",
    "vocabulary_size = len(set(concat))\n",
    "data, count, dictionary, rev_dictionary = build_dataset(concat, vocabulary_size)\n",
    "\n",
    "# The first four `count` entries are the special tokens, hence the [4:10] slice.\n",
    "print('vocab from size: %d' % vocabulary_size)\n",
    "print('Most common words', count[4:10])\n",
    "print('Sample data', data[:10], [rev_dictionary[token_id] for token_id in data[:10]])\n",
    "print('filtered vocab size:', len(dictionary))\n",
    "print(\"% of vocab used: {}%\".format(round(len(dictionary) / vocabulary_size, 4) * 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of WordPiece positions per encoder input, including [CLS] and [SEP].\n",
    "MAX_SEQ_LENGTH = 200\n",
    "# Check that lower-casing agrees with the (uncased) checkpoint before building the tokenizer.\n",
    "tokenization.validate_case_matches_checkpoint(do_lower_case=True, init_checkpoint=BERT_INIT_CHKPNT)\n",
    "tokenizer = tokenization.FullTokenizer(vocab_file=BERT_VOCAB, do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4395/4395 [00:14<00:00, 295.96it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "input_ids, input_masks, segment_ids = [], [], []\n",
    "\n",
    "for text in tqdm(c):\n",
    "    # Keep at most MAX_SEQ_LENGTH - 2 word pieces so [CLS] and [SEP] still fit.\n",
    "    pieces = tokenizer.tokenize(text)[: MAX_SEQ_LENGTH - 2]\n",
    "    ids = tokenizer.convert_tokens_to_ids([\"[CLS]\"] + pieces + [\"[SEP]\"])\n",
    "    n_pad = MAX_SEQ_LENGTH - len(ids)\n",
    "\n",
    "    # Right-pad with zeros to exactly MAX_SEQ_LENGTH; the mask is 1 on real tokens only\n",
    "    # and every position belongs to segment 0.\n",
    "    input_ids.append(ids + [0] * n_pad)\n",
    "    input_masks.append([1] * len(ids) + [0] * n_pad)\n",
    "    segment_ids.append([0] * (len(ids) + n_pad))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the pre-trained BERT encoder, read from the release's JSON config file.\n",
    "bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ids of the reserved special tokens in the headline vocabulary.\n",
    "GO, PAD, EOS, UNK = (dictionary[token] for token in ('GO', 'PAD', 'EOS', 'UNK'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'daman and diu revokes mandatory rakshabandhan in offices order EOS'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Terminate every headline with the EOS token.\n",
    "# The guard makes the cell idempotent: re-running it no longer appends a second ' EOS'.\n",
    "for i in range(len(h)):\n",
    "    if not h[i].endswith(' EOS'):\n",
    "        h[i] = h[i] + ' EOS'\n",
    "h[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): bert_config is already loaded by an earlier cell; this reload looks redundant.\n",
    "bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)\n",
    "epoch = 20\n",
    "batch_size = 5\n",
    "# Fraction of the training steps reserved for warm-up.\n",
    "warmup_proportion = 0.1\n",
    "# Step count over all epochs, computed on the full dataset (before any train/test split).\n",
    "num_train_steps = int(len(input_ids) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed_seq(x, vocab_sz, embed_dim, name, zero_pad=True):\n",
    "    \"\"\"Embed the ids in `x` with a [vocab_sz, embed_dim] variable called `name`.\n",
    "\n",
    "    With `zero_pad` set, row 0 of the table is replaced by zeros before the\n",
    "    lookup, so id 0 always embeds to the zero vector.\n",
    "    \"\"\"\n",
    "    table = tf.get_variable(name, [vocab_sz, embed_dim])\n",
    "    if zero_pad:\n",
    "        zero_row = tf.zeros([1, embed_dim])\n",
    "        table = tf.concat([zero_row, table[1:, :]], 0)\n",
    "    return tf.nn.embedding_lookup(table, x)\n",
    "\n",
    "def position_encoding(inputs):\n",
    "    \"\"\"Build sinusoidal position features matching `inputs` along the batch and time axes.\n",
    "\n",
    "    Positions run over axis 1 of `inputs`; the frequencies are derived from the\n",
    "    static size of its last axis. The sine features come first and the cosine\n",
    "    features are concatenated after them (not interleaved); the result is tiled\n",
    "    once per batch row.\n",
    "    \"\"\"\n",
    "    seq_len = tf.shape(inputs)[1]\n",
    "    repr_dim = inputs.get_shape()[-1].value\n",
    "    positions = tf.reshape(tf.range(0.0, tf.to_float(seq_len), dtype=tf.float32), [-1, 1])\n",
    "    even_dims = np.arange(0, repr_dim, 2, np.float32)\n",
    "    denom = np.reshape(np.power(10000.0, even_dims / repr_dim), [1, -1])\n",
    "    angles = positions / denom\n",
    "    enc = tf.expand_dims(tf.concat([tf.sin(angles), tf.cos(angles)], 1), 0)\n",
    "    return tf.tile(enc, [tf.shape(inputs)[0], 1, 1])\n",
    "\n",
    "def layer_norm(inputs, epsilon=1e-8):\n",
    "    \"\"\"Normalise `inputs` over the last axis, then apply a learned scale and shift.\n",
    "\n",
    "    Creates the variables 'gamma' (initialised to ones) and 'beta' (initialised to\n",
    "    zeros) in the current variable scope; `epsilon` keeps the square root away from zero.\n",
    "    \"\"\"\n",
    "    mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)\n",
    "    feature_shape = inputs.get_shape()[-1:]\n",
    "    gamma = tf.get_variable('gamma', feature_shape, tf.float32, tf.ones_initializer())\n",
    "    beta = tf.get_variable('beta', feature_shape, tf.float32, tf.zeros_initializer())\n",
    "    normalized = (inputs - mean) / tf.sqrt(variance + epsilon)\n",
    "    return gamma * normalized + beta\n",
    "\n",
    "\n",
    "def cnn_block(x, dilation_rate, pad_sz, hidden_dim, kernel_size):\n",
    "    \"\"\"Layer norm, zero-padded dilated 1-D convolution, then ReLU.\n",
    "\n",
    "    `pad_sz` all-zero steps of width `hidden_dim` are concatenated at both ends of\n",
    "    axis 1 before the convolution, and the last `pad_sz` output steps are dropped\n",
    "    afterwards. The zero padding is concatenated with `x`, so the last axis of `x`\n",
    "    has to be `hidden_dim` wide.\n",
    "    \"\"\"\n",
    "    normed = layer_norm(x)\n",
    "    zeros = tf.zeros([tf.shape(normed)[0], pad_sz, hidden_dim])\n",
    "    padded = tf.concat([zeros, normed, zeros], 1)\n",
    "    conv = tf.layers.conv1d(inputs = padded,\n",
    "                            filters = hidden_dim,\n",
    "                            kernel_size = kernel_size,\n",
    "                            dilation_rate = dilation_rate)\n",
    "    return tf.nn.relu(conv[:, :-pad_sz, :])\n",
    "\n",
    "class Translator:\n",
    "    \"\"\"Sequence-to-sequence graph: BERT encoder plus a dilated-convolution decoder with multi-head attention.\n",
    "\n",
    "    Constructing an instance adds the placeholders, training loss, optimizer and\n",
    "    accuracy ops to the default graph. Reads the notebook globals `bert_config`,\n",
    "    `modeling` and `GO`.\n",
    "    \"\"\"\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, \n",
    "                 to_dict_size, learning_rate, kernel_size = 2, n_attn_heads = 16):\n",
    "        \"\"\"Build the training graph.\n",
    "\n",
    "        Args:\n",
    "            size_layer: width of the decoder layers.\n",
    "            num_layers: number of decoder blocks; block i uses dilation 2 ** i.\n",
    "            embedded_size: width of the decoder token embeddings.\n",
    "            to_dict_size: size of the output vocabulary.\n",
    "            learning_rate: learning rate handed to the Adam optimizer.\n",
    "            kernel_size: kernel width of the decoder convolutions.\n",
    "            n_attn_heads: number of attention heads per decoder block.\n",
    "        \"\"\"\n",
    "        \n",
    "        # Encoder input ids and target ids, both int32 of shape [batch, time].\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        # Sequence lengths are the per-row counts of non-zero ids.\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype=tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        self.batch_size = batch_size\n",
    "        \n",
    "        # Decoder input = targets shifted right: drop the last column, prepend GO.\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        # Decoder embedding table, initialised uniformly in [-1, 1).\n",
    "        self.embedding = tf.Variable(tf.random_uniform([to_dict_size, embedded_size], -1, 1))\n",
    "        \n",
    "        self.num_layers = num_layers\n",
    "        self.kernel_size = kernel_size\n",
    "        self.size_layer = size_layer\n",
    "        self.n_attn_heads = n_attn_heads\n",
    "        self.dict_size = to_dict_size\n",
    "        \n",
    "        self.training_logits = self.forward(self.X, decoder_input)\n",
    "\n",
    "        # Loss weights: 1.0 on the first Y_seq_len positions of each row, 0.0 afterwards.\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        # Accuracy over the masked positions of the arg-max predictions.\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        # NOTE(review): correct_index is never read.\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
    "        \n",
    "    def forward(self, x, y, reuse = False):\n",
    "        \"\"\"Return vocabulary logits for decoder inputs `y` given encoder ids `x`.\n",
    "\n",
    "        Args:\n",
    "            x: encoder input ids fed to BERT.\n",
    "            y: decoder input ids, looked up in `self.embedding`.\n",
    "            reuse: passed to every variable scope; set it to True to build a second\n",
    "                pass that shares the variables of the first one.\n",
    "\n",
    "        Returns:\n",
    "            A tensor of logits whose last axis has `self.dict_size` entries.\n",
    "        \"\"\"\n",
    "        # All variables end up under the nested scope 'forward/forward/'.\n",
    "        with tf.variable_scope('forward',reuse=reuse):\n",
    "            with tf.variable_scope('forward',reuse=reuse):\n",
    "                # NOTE(review): no input_mask is passed and is_training is always True, also for\n",
    "                # the reuse=True pass - presumably padding is attended to and BERT's dropout\n",
    "                # stays active while decoding; confirm this is intended.\n",
    "                model = modeling.BertModel(\n",
    "                config=bert_config,\n",
    "                is_training=True,\n",
    "                input_ids=x,\n",
    "                use_one_hot_embeddings=False)\n",
    "                encoder_embedded = model.get_sequence_output()\n",
    "                decoder_embedded = tf.nn.embedding_lookup(self.embedding, y)\n",
    "                \n",
    "                # g keeps the raw token embeddings; every block's attention query adds them back in.\n",
    "                g = tf.identity(decoder_embedded)\n",
    "                for i in range(self.num_layers):\n",
    "                    dilation_rate = 2 ** i\n",
    "                    pad_sz = (self.kernel_size - 1) * dilation_rate\n",
    "                    with tf.variable_scope('decode_%d'%i,reuse=reuse):\n",
    "                        attn_res = h = cnn_block(decoder_embedded, dilation_rate, \n",
    "                                                 pad_sz, self.size_layer, self.kernel_size)\n",
    "                        # One context tensor per head, each size_layer // n_attn_heads wide.\n",
    "                        # The dense layers are unnamed, so the reuse=True pass presumably depends\n",
    "                        # on them being created in exactly this order.\n",
    "                        C = []\n",
    "                        for j in range(self.n_attn_heads):\n",
    "                            h_ = tf.layers.dense(h, self.size_layer//self.n_attn_heads)\n",
    "                            g_ = tf.layers.dense(g, self.size_layer//self.n_attn_heads)\n",
    "                            zu_ = tf.layers.dense(encoder_embedded, self.size_layer//self.n_attn_heads)\n",
    "                            ze_ = tf.layers.dense(encoder_embedded, self.size_layer//self.n_attn_heads)\n",
    "\n",
    "                            # Query d against encoder keys zu_; softmax over encoder positions;\n",
    "                            # the weights then mix the encoder values ze_.\n",
    "                            d = tf.layers.dense(h_, self.size_layer//self.n_attn_heads) + g_\n",
    "                            dz = tf.matmul(d, tf.transpose(zu_, [0, 2, 1]))\n",
    "                            a = tf.nn.softmax(dz)\n",
    "                            c_ = tf.matmul(a, ze_)\n",
    "                            C.append(c_)\n",
    "\n",
    "                        # Concatenate the heads, merge with the convolution output, add residually.\n",
    "                        c = tf.concat(C, 2)\n",
    "                        h = tf.layers.dense(attn_res + c, self.size_layer)\n",
    "                        decoder_embedded += h\n",
    "\n",
    "                return tf.layers.dense(decoder_embedded, self.dict_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder hyper-parameters. size_layer and embedded_size have to match: cnn_block\n",
    "# concatenates zero padding of width size_layer with the embedded decoder input.\n",
    "size_layer = 256\n",
    "num_layers = 4\n",
    "embedded_size = 256\n",
    "learning_rate = 2e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def beam_search_decoding(length = 20, beam_width = 5):\n",
    "    \"\"\"Add beam-search decoding for the global `model` to the graph and return the decoded ids.\n",
    "\n",
    "    Every sequence starts from GO and is decoded for up to `length` steps with\n",
    "    `beam_width` beams; EOS is the end-of-sequence id. Only the ids tensor returned\n",
    "    by `beam_search.beam_search` is kept.\n",
    "    \"\"\"\n",
    "    def symbols_to_logits(ids):\n",
    "        # Re-run the whole forward pass on the prefix decoded so far (sharing the\n",
    "        # training variables) and keep only the logits of the newest position.\n",
    "        tiled_x = tf.contrib.seq2seq.tile_batch(model.X, beam_width)\n",
    "        all_logits = model.forward(tiled_x, ids, reuse = True)\n",
    "        newest = tf.shape(ids)[1]-1\n",
    "        return all_logits[:, newest, :]\n",
    "\n",
    "    start_ids = tf.fill([model.batch_size], GO)\n",
    "    decoded_ids, _, _ = beam_search.beam_search(\n",
    "        symbols_to_logits,\n",
    "        start_ids,\n",
    "        beam_width,\n",
    "        length,\n",
    "        len(dictionary),\n",
    "        0.0,\n",
    "        eos_id = EOS)\n",
    "\n",
    "    return decoded_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n",
      "  'a.item() instead', DeprecationWarning, stacklevel=1)\n"
     ]
    }
   ],
   "source": [
    "# Start from an empty default graph, then open the session the later cells use.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Translator(size_layer, num_layers, embedded_size, \n",
    "                len(dictionary), learning_rate)\n",
    "# The decoding graph calls forward(reuse=True), so it has to be built after the training graph.\n",
    "model.generate = beam_search_decoding()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint, scope_prefix = 'forward/forward/'):\n",
    "    \"\"\"Match the checkpoint's BERT variables to the graph variables that should receive them.\n",
    "\n",
    "    Only checkpoint variables whose name contains 'bert' and that have a graph\n",
    "    counterpart named `scope_prefix + name` are mapped; all others are skipped\n",
    "    instead of raising KeyError.\n",
    "\n",
    "    Args:\n",
    "        tvars: graph variables to consider, e.g. tf.trainable_variables().\n",
    "        init_checkpoint: path (prefix) of the checkpoint to read variable names from.\n",
    "        scope_prefix: variable-scope prefix under which the BERT variables live in the graph.\n",
    "\n",
    "    Returns:\n",
    "        (assignment_map, initialized_variable_names): an OrderedDict mapping\n",
    "        checkpoint variable name -> graph variable, and a dict whose keys are the\n",
    "        matched checkpoint names with and without the ':0' suffix.\n",
    "    \"\"\"\n",
    "    # Index the graph variables by name without the ':<output index>' suffix.\n",
    "    name_to_variable = collections.OrderedDict()\n",
    "    for var in tvars:\n",
    "        name = var.name\n",
    "        m = re.match('^(.*):\\\\d+$', name)\n",
    "        if m is not None:\n",
    "            name = m.group(1)\n",
    "        name_to_variable[name] = var\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    initialized_variable_names = {}\n",
    "    for name, _shape in tf.train.list_variables(init_checkpoint):\n",
    "        if 'bert' not in name:\n",
    "            continue\n",
    "        graph_name = scope_prefix + name\n",
    "        if graph_name not in name_to_variable:\n",
    "            # Present in the checkpoint but not in this graph (or not in tvars).\n",
    "            continue\n",
    "        assignment_map[name] = name_to_variable[graph_name]\n",
    "        initialized_variable_names[name] = 1\n",
    "        initialized_variable_names[name + ':0'] = 1\n",
    "\n",
    "    return (assignment_map, initialized_variable_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair the BERT checkpoint variables with their counterparts among the trainable variables.\n",
    "tvars = tf.trainable_variables()\n",
    "assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(\n",
    "    tvars, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from uncased_L-12_H-768_A-12/bert_model.ckpt\n"
     ]
    }
   ],
   "source": [
    "# Restore only the variables listed in assignment_map (the BERT weights) from the checkpoint.\n",
    "saver = tf.train.Saver(var_list = assignment_map)\n",
    "saver.restore(sess, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic):\n",
    "    \"\"\"Turn each whitespace-tokenised string of `corpus` into a list of ids from `dic`.\n",
    "\n",
    "    Tokens missing from `dic` are mapped to the global `UNK` id.\n",
    "    \"\"\"\n",
    "    return [[dic.get(token, UNK) for token in line.split()] for line in corpus]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target id sequences, one per headline in h.\n",
    "Y = str_idx(h, dictionary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "# sklearn.cross_validation was deprecated in 0.18 and removed in 0.20;\n",
    "# sklearn.model_selection provides the same train_test_split.\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hold out 5% of the (encoder input ids, headline ids) pairs.\n",
    "# NOTE(review): no random_state is set, so the split differs between runs.\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(input_ids, Y, test_size = 0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    \"\"\"Right-pad every sequence in a batch to the length of the longest one.\n",
    "\n",
    "    Args:\n",
    "        sentence_batch: iterable of sequences (lists or tuples) of token ids.\n",
    "        pad_int: id appended to the shorter sequences.\n",
    "\n",
    "    Returns:\n",
    "        (padded_seqs, seq_lens): the padded sequences as new lists, and the\n",
    "        original length of each sequence. An empty batch yields ([], []).\n",
    "    \"\"\"\n",
    "    # Materialise as lists so tuples (or a one-shot iterable) are handled too.\n",
    "    sentences = [list(sentence) for sentence in sentence_batch]\n",
    "    seq_lens = [len(sentence) for sentence in sentences]\n",
    "    # default = 0 keeps an empty batch from raising ValueError inside max().\n",
    "    max_sentence_len = max(seq_lens, default = 0)\n",
    "    padded_seqs = [\n",
    "        sentence + [pad_int] * (max_sentence_len - length)\n",
    "        for sentence, length in zip(sentences, seq_lens)\n",
    "    ]\n",
    "    return padded_seqs, seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[   1, 4700, 1519, 4183, 1519, 6135, 1519,  725,  982, 1045,\n",
       "         7026, 8370, 1519,   42, 2505, 5440,  567, 2025, 6334, 7626,\n",
       "         6152],\n",
       "        [   1, 4700, 1519, 4183, 1519, 6135, 1519,  725,  982, 1045,\n",
       "         7026, 8370, 1519,   42, 2505, 5440,  567, 2025, 6334, 4224,\n",
       "          932],\n",
       "        [   1, 4700, 1519, 4183, 1519, 6135, 1519,  725,  982, 1045,\n",
       "         7026, 8370, 1519,   42, 2505, 5440,  567, 5651, 1519, 5578,\n",
       "         5423],\n",
       "        [   1, 4700, 1519, 4183, 1519, 6135, 1519,  725,  982, 1045,\n",
       "         7026, 8370, 1519,   42, 2505, 5440,  567, 5651, 1519, 5578,\n",
       "          434],\n",
       "        [   1, 4700, 1519, 4183, 1519, 6135, 1519,  725,  982, 1045,\n",
       "         7026, 8370, 1519,   42, 2505, 5440,  567, 5651, 1519, 8244,\n",
       "         3935]]], dtype=int32)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Feed the first held-out example as a batch of one and display the raw ids\n",
    "# returned by model.generate.\n",
    "sess.run(\n",
    "    model.generate,\n",
    "    feed_dict = {model.X: [test_X[0]]},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:47<00:00,  1.42it/s, accuracy=0.123, cost=7.57] \n",
      "test minibatch loop: 100%|██████████| 44/44 [00:13<00:00,  3.80it/s, accuracy=0.135, cost=7.56] \n",
      "train minibatch loop:   0%|          | 0/835 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 600.4188756942749\n",
      "epoch: 0, training loss: 7.504465, training acc: 0.103313, valid loss: 7.265235, valid acc: 0.113662\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:43<00:00,  1.42it/s, accuracy=0.105, cost=7.22] \n",
      "test minibatch loop: 100%|██████████| 44/44 [00:11<00:00,  3.90it/s, accuracy=0.115, cost=7.46] \n",
      "train minibatch loop:   0%|          | 0/835 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 595.078351020813\n",
      "epoch: 1, training loss: 6.944261, training acc: 0.116727, valid loss: 7.155360, valid acc: 0.119874\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:44<00:00,  1.43it/s, accuracy=0.0877, cost=6.82]\n",
      "test minibatch loop: 100%|██████████| 44/44 [00:11<00:00,  3.95it/s, accuracy=0.135, cost=7.37] \n",
      "train minibatch loop:   0%|          | 0/835 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 595.7576084136963\n",
      "epoch: 2, training loss: 6.654561, training acc: 0.128684, valid loss: 7.047701, valid acc: 0.127980\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:43<00:00,  1.44it/s, accuracy=0.105, cost=6.4]  \n",
      "test minibatch loop: 100%|██████████| 44/44 [00:11<00:00,  3.71it/s, accuracy=0.135, cost=7.27] \n",
      "train minibatch loop:   0%|          | 0/835 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 594.9749147891998\n",
      "epoch: 3, training loss: 6.328232, training acc: 0.140889, valid loss: 6.963283, valid acc: 0.132893\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:44<00:00,  1.43it/s, accuracy=0.123, cost=6.08] \n",
      "test minibatch loop: 100%|██████████| 44/44 [00:11<00:00,  3.98it/s, accuracy=0.135, cost=7.19] \n",
      "train minibatch loop:   0%|          | 0/835 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 595.5642290115356\n",
      "epoch: 4, training loss: 5.996441, training acc: 0.154711, valid loss: 6.916960, valid acc: 0.137149\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:43<00:00,  1.43it/s, accuracy=0.123, cost=5.73] \n",
      "test minibatch loop: 100%|██████████| 44/44 [00:11<00:00,  3.90it/s, accuracy=0.115, cost=7.17] \n",
      "train minibatch loop:   0%|          | 0/835 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 595.1420586109161\n",
      "epoch: 5, training loss: 5.667774, training acc: 0.168648, valid loss: 6.879445, valid acc: 0.141117\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:43<00:00,  1.42it/s, accuracy=0.211, cost=5.23] \n",
      "test minibatch loop: 100%|██████████| 44/44 [00:11<00:00,  3.70it/s, accuracy=0.0962, cost=7.14]\n",
      "train minibatch loop:   0%|          | 0/835 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 594.8970482349396\n",
      "epoch: 6, training loss: 5.350092, training acc: 0.184197, valid loss: 6.904361, valid acc: 0.140864\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:43<00:00,  1.44it/s, accuracy=0.175, cost=4.89] \n",
      "test minibatch loop: 100%|██████████| 44/44 [00:11<00:00,  3.76it/s, accuracy=0.135, cost=7.18] \n",
      "train minibatch loop:   0%|          | 0/835 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 595.3370399475098\n",
      "epoch: 7, training loss: 5.039291, training acc: 0.200931, valid loss: 6.947434, valid acc: 0.142766\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:43<00:00,  1.44it/s, accuracy=0.263, cost=4.62] \n",
      "test minibatch loop: 100%|██████████| 44/44 [00:11<00:00,  3.80it/s, accuracy=0.0962, cost=7.22]\n",
      "train minibatch loop:   0%|          | 0/835 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 594.4455664157867\n",
      "epoch: 8, training loss: 4.733280, training acc: 0.219349, valid loss: 7.004194, valid acc: 0.138716\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 835/835 [09:43<00:00,  1.45it/s, accuracy=0.298, cost=4.3] \n",
      "test minibatch loop: 100%|██████████| 44/44 [00:11<00:00,  3.80it/s, accuracy=0.0769, cost=7.38]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 594.7499196529388\n",
      "epoch: 9, training loss: 4.430527, training acc: 0.243508, valid loss: 7.084307, valid acc: 0.135056\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# Train for a fixed number of passes; after each pass, score the held-out split.\n",
    "for EPOCH in range(10):\n",
    "    lasttime = time.time()\n",
    "\n",
    "    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
    "\n",
    "    # Start offset of every training minibatch; the last batch may be shorter.\n",
    "    train_starts = range(0, len(train_X), batch_size)\n",
    "    pbar = tqdm(train_starts, desc = 'train minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_X))\n",
    "        batch_x = train_X[i : index]\n",
    "        # Target sequences have different lengths, so pad them with PAD.\n",
    "        batch_y, _ = pad_sentence_batch(train_Y[i : index], PAD)\n",
    "        # model.optimizer is fetched only on the training pass\n",
    "        # (presumably the parameter-update op -- confirm in the model definition).\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x\n",
    "            },\n",
    "        )\n",
    "        # Abort as soon as the loss turns into NaN.\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss += cost\n",
    "        train_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    # Same batching over the held-out split, without fetching the optimizer.\n",
    "    test_starts = range(0, len(test_X), batch_size)\n",
    "    pbar = tqdm(test_starts, desc = 'test minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_X))\n",
    "        batch_x = test_X[i : index]\n",
    "        batch_y, _ = pad_sentence_batch(test_Y[i : index], PAD)\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x\n",
    "            },\n",
    "        )\n",
    "        test_loss += cost\n",
    "        test_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    # Average over the number of batches that actually ran. Dividing by\n",
    "    # len(X) / batch_size is wrong whenever the last batch is partial: that\n",
    "    # quotient is fractional and smaller than the real batch count, which\n",
    "    # inflates every reported mean. max(..., 1) avoids dividing by zero on an\n",
    "    # empty split.\n",
    "    train_batches = max(len(train_starts), 1)\n",
    "    test_batches = max(len(test_starts), 1)\n",
    "    train_loss /= train_batches\n",
    "    train_acc /= train_batches\n",
    "    test_loss /= test_batches\n",
    "    test_acc /= test_batches\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'GO govt bans to set in maharashtra EOS PAD'"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run generation on the first held-out example, take entry [0, 0, :] of the\n",
    "# result and map each id back to its token.\n",
    "generated = [\n",
    "    rev_dictionary[i]\n",
    "    for i in sess.run(model.generate, feed_dict = {model.X: [test_X[0]]})[0, 0, :]\n",
    "]\n",
    "' '.join(generated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1l for pregnant women to expose sex determination centres EOS'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Target sequence of the same held-out example, decoded from its ids for\n",
    "# comparison with the generated text.\n",
    "' '.join(rev_dictionary[i] for i in test_Y[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
